syntax = "proto3";

package tensorflow.data;

import "tensorflow/core/data/service/common.proto";
import "tensorflow/core/framework/tensor.proto";
import "tensorflow/core/protobuf/data_service.proto";
import "tensorflow/core/protobuf/snapshot.proto";

// Describes one task reported in `WorkerHeartbeatRequest.active_tasks`.
// Next tag: 3
message ActiveTask {
  // ID of the task.
  // NOTE(review): presumably the same ID space as
  // `WorkerHeartbeatRequest.current_tasks` -- confirm.
  int64 task_id = 1;
  // Estimated time it takes this Task to produce an element, in nanoseconds.
  double processing_time_nsec = 2;
}

// Request message for `DispatcherService.WorkerHeartbeat`, the periodic worker
// heartbeat.
// Next tag: 9
message WorkerHeartbeatRequest {
  // Address of the worker.
  // NOTE(review): address format is not visible in this file -- confirm.
  string worker_address = 1;
  // Data transfer servers reported by the worker. `DataTransferServerInfo` is
  // declared in one of the imported files.
  repeated DataTransferServerInfo transfer_servers = 7;
  // Tags reported by the worker.
  repeated string worker_tags = 4;
  // The UID of the worker Borg job, used for telemetry.
  int64 worker_uid = 5;
  // IDs of the worker's current tasks. Slated to be superseded by
  // `active_tasks` (see the TODO on that field).
  repeated int64 current_tasks = 2;
  // The status of any active snapshot tasks, keyed by snapshot path.
  map<string, SnapshotTaskProgress> snapshot_task_progress = 6;
  // Tag 3 is reserved and must not be assigned to a new field.
  reserved 3;
  // TODO(armandouv): Deprecate current_tasks and extract task ids from here.
  repeated ActiveTask active_tasks = 8;
}

// Response message for `DispatcherService.WorkerHeartbeat`.
// Next tag: 4
message WorkerHeartbeatResponse {
  // New tasks for the worker.
  // NOTE(review): presumably tasks the worker should start running -- confirm.
  repeated TaskDef new_tasks = 1;
  // IDs of tasks to delete.
  // NOTE(review): presumably tasks the worker should stop and drop -- confirm.
  repeated int64 tasks_to_delete = 2;
  // Snapshots to process.
  repeated SnapshotTaskDef snapshot_tasks = 3;
}

// Progress report for a single task, carried in `WorkerUpdateRequest.updates`.
// Next tag: 3
message TaskProgress {
  // The task that this message is about.
  int64 task_id = 1;
  // Whether the task has completed.
  bool completed = 2;
}

// Request message for `DispatcherService.WorkerUpdate`, which updates the
// dispatcher with information about the worker's state.
// Next tag: 3
message WorkerUpdateRequest {
  // Address of the worker.
  // NOTE(review): presumably the worker sending the update -- confirm.
  string worker_address = 1;
  // Per-task progress updates.
  repeated TaskProgress updates = 2;
}

// Response message for `DispatcherService.WorkerUpdate`. Currently has no
// fields.
// Next tag: 1
message WorkerUpdateResponse {}

// Request message for `DispatcherService.GetDatasetDef`.
// Next tag: 2
message GetDatasetDefRequest {
  // ID of the dataset whose definition is requested.
  string dataset_id = 1;
}

// Response message for `DispatcherService.GetDatasetDef`.
// Next tag: 2
message GetDatasetDefResponse {
  // The dataset definition. `DatasetDef` is declared in one of the imported
  // files.
  DatasetDef dataset_def = 1;
}

// Request message for `DispatcherService.GetSplit`, which gets the next split
// for a given iteration.
// Next tag: 4
message GetSplitRequest {
  // ID of the iteration to get a split for.
  int64 iteration_id = 1;
  // NOTE(review): presumably the repetition of the iteration's dataset that
  // the split should come from -- confirm against the dispatcher.
  int64 repetition = 2;
  // NOTE(review): presumably selects one of several split providers for the
  // iteration -- confirm against the dispatcher.
  int64 split_provider_index = 3;
}

// Response message for `DispatcherService.GetSplit`.
// Next tag: 3
message GetSplitResponse {
  // The split, encoded as a tensor.
  TensorProto split = 1;
  // NOTE(review): presumably true when no splits are left, in which case
  // `split` would carry no data -- confirm. Unlike in
  // `GetSnapshotSplitResponse`, the two fields are not in a oneof here.
  bool end_of_splits = 2;
}

// Request message for `DispatcherService.GetVersion`. Currently has no fields.
// Next tag: 1
message GetVersionRequest {}

// Response message for `DispatcherService.GetVersion`.
// Next tag: 2
message GetVersionResponse {
  // The API version of the server.
  int64 version = 1;
}

// Request message for `DispatcherService.GetOrRegisterDataset`, which
// registers a dataset or returns its id if it is already registered.
// Next tag: 5
message GetOrRegisterDatasetRequest {
  // The dataset to register.
  DatasetDef dataset = 1;
  // Metadata related to tf.data service.
  DataServiceMetadata metadata = 3;
  // Single-member oneof: gives `dataset_id` explicit presence, so "not
  // provided" can be told apart from an empty string.
  oneof optional_dataset_id {
    // If provided, tf.data service will register the dataset with the specified
    // ID. Otherwise, it will generate a unique dataset ID.
    string dataset_id = 4;
  }
  // Tag 2 is reserved and must not be assigned to a new field.
  reserved 2;
}

// Response message for `DispatcherService.GetOrRegisterDataset`.
// Next tag: 2
message GetOrRegisterDatasetResponse {
  // The id for the registered dataset.
  string dataset_id = 1;
}

// Request message for `DispatcherService.GetDataServiceMetadata`.
// Next tag: 2
message GetDataServiceMetadataRequest {
  // The id of the registered dataset whose data service metadata is requested.
  string dataset_id = 1;
}

// Response message for `DispatcherService.GetDataServiceMetadata`.
// Next tag: 2
message GetDataServiceMetadataResponse {
  // The retrieved data service dataset metadata.
  DataServiceMetadata metadata = 1;
}
// Request message for `DispatcherService.GetDataServiceConfig`. Currently has
// no fields.
// Next tag: 1
message GetDataServiceConfigRequest {}

// Response message for `DispatcherService.GetDataServiceConfig`.
// Next tag: 2
message GetDataServiceConfigResponse {
  // The config of the data service cluster.
  DataServiceConfig config = 1;
}

// Request message for `DispatcherService.GetOrCreateJob`, which gets a job if
// it already exists and otherwise creates it.
// Next tag: 7
message GetOrCreateJobRequest {
  // The id of the dataset to create a job for.
  string dataset_id = 1;
  // A mode controlling how the tf.data service produces data for the job.
  ProcessingModeDef processing_mode_def = 2;
  // Optional job name identifying a shared job. If not set, the RPC will always
  // create a new job.
  // The single-member oneof gives `job_name` explicit presence (set vs. unset).
  oneof optional_job_name {
    string job_name = 3;
  }
  // Optional number of consumers. If set, the job's tasks will provide
  // their elements to consumers round-robin.
  // The single-member oneof gives `num_consumers` explicit presence.
  oneof optional_num_consumers {
    int64 num_consumers = 4;
  }
  // True if cross-trainer cache is enabled.
  bool use_cross_trainer_cache = 5;
  // Specifies which workers the client of this job reads from.
  TargetWorkers target_workers = 6;
}

// Response message for `DispatcherService.GetOrCreateJob`.
// Next tag: 2
message GetOrCreateJobResponse {
  // ID of the job that was found or created.
  int64 job_id = 1;
}

// Request message for `DispatcherService.GetOrCreateIteration`, which gets an
// iteration if it already exists and otherwise creates it.
// Next tag: 3
message GetOrCreateIterationRequest {
  // The job to create an iteration for.
  int64 job_id = 1;
  // Which repetition of the job to read from.
  int64 repetition = 2;
}

// Response message for `DispatcherService.GetOrCreateIteration`.
// Next tag: 2
message GetOrCreateIterationResponse {
  // An id for the client that will read from the iteration. When the client is
  // done with the iteration, it should call ReleaseIterationClient with this
  // id.
  int64 iteration_client_id = 1;
}

// Request message for `DispatcherService.MaybeRemoveTask`, which attempts to
// remove a task from a round-robin read iteration.
// Next tag: 4
message MaybeRemoveTaskRequest {
  // ID of the task to try to remove.
  int64 task_id = 1;
  // NOTE(review): presumably the index of the consumer issuing the request,
  // within the job's `num_consumers` -- confirm.
  int64 consumer_index = 2;
  // NOTE(review): presumably the round-robin round the request applies to --
  // confirm.
  int64 round = 3;
}

// Response message for `DispatcherService.MaybeRemoveTask`.
// Next tag: 2
message MaybeRemoveTaskResponse {
  // NOTE(review): presumably true if the task was removed -- confirm.
  bool removed = 1;
}

// Request message for `DispatcherService.ReleaseIterationClient`, which
// releases an iteration client so that the iteration may eventually be cleaned
// up.
// Next tag: 2
message ReleaseIterationClientRequest {
  // The iteration client id to release, as returned in
  // `GetOrCreateIterationResponse.iteration_client_id`.
  int64 iteration_client_id = 1;
}

// Response message for `DispatcherService.ReleaseIterationClient`. Currently
// has no fields.
// Next tag: 1
message ReleaseIterationClientResponse {}

// Request message for `DispatcherService.ClientHeartbeat`, which lets the
// dispatcher know that the client is still active.
// Next tag: 6
message ClientHeartbeatRequest {
  // Tag 3 is reserved and must not be assigned to a new field.
  reserved 3;
  // The iteration client id to heartbeat for.
  int64 iteration_client_id = 1;
  // Reports which round the client is currently reading from when doing
  // round-robin reads.
  // The single-member oneof gives `current_round` explicit presence.
  oneof optional_current_round {
    int64 current_round = 2;
  }
  // Reports whether the client has successfully blocked the indicated round
  // from starting. This enables the dispatcher to add a new task in the
  // blocked round or later.
  // The single-member oneof gives `blocked_round` explicit presence.
  oneof optional_blocked_round {
    int64 blocked_round = 4;
  }
  // Target processing time in nanoseconds observed by the client.
  double target_processing_time_nsec = 5;
}

// Response message for `DispatcherService.ClientHeartbeat`, through which the
// dispatcher can notify the client of new tasks.
// Next tag: 5
message ClientHeartbeatResponse {
  // A list of all tasks that the client should read from.
  repeated TaskInfo task_info = 1;
  // Tells the client not to start the given round if possible.
  // The single-member oneof gives `block_round` explicit presence.
  oneof optional_block_round {
    int64 block_round = 3;
  }
  // Whether the iteration has finished.
  bool iteration_finished = 2;
  // tf.data service deployment mode. Supported values are "REMOTE",
  // "COLOCATED", and "HYBRID". If unspecified, it is assumed to be "REMOTE".
  DeploymentMode deployment_mode = 4;
}

// Describes one worker, as listed in `GetWorkersResponse.workers`.
// Next tag: 3
message WorkerInfo {
  // Address of the worker.
  string address = 1;
  // Tag 2 is reserved and must not be assigned to a new field.
  reserved 2;
}

// Request message for `DispatcherService.GetWorkers`. Currently has no fields.
// Next tag: 1
message GetWorkersRequest {}

// Response message for `DispatcherService.GetWorkers`, which reports all
// workers registered with the dispatcher.
// Next tag: 2
message GetWorkersResponse {
  // A list of all workers.
  repeated WorkerInfo workers = 1;
}

// Request message for `DispatcherService.Snapshot`, which initiates the
// process of materializing a dataset's output to disk.
// Next tag: 4
message SnapshotRequest {
  // The dataset to snapshot.
  DatasetDef dataset = 1;

  // The path to which to materialize the snapshot.
  string path = 2;

  // The metadata for the snapshot.
  experimental.DistributedSnapshotMetadata metadata = 3;
}

// Response message for `DispatcherService.Snapshot`. Currently has no fields.
// Next tag: 1
message SnapshotResponse {}

// Request message for `DispatcherService.GetSnapshotSplit`, which gets the
// next split for the given stream of the given snapshot.
// Fields are declared out of tag order; the tag numbers are part of the wire
// format and must not be renumbered.
// Next tag: 6
message GetSnapshotSplitRequest {
  // The address of the worker requesting the split.
  string worker_address = 4;

  // The base path of the snapshot materialization.
  string base_path = 1;

  // The index of the snapshot stream from which to get the split.
  int64 stream_index = 2;

  // The index of the dataset source from which to get the split.
  int64 source_index = 3;

  // The repetition of the dataset from which to get the split.
  int64 repetition_index = 5;
}

// Response message for `DispatcherService.GetSnapshotSplit`.
// Next tag: 4
message GetSnapshotSplitResponse {
  // At most one of the following is set: either a split to process, or the
  // end-of-splits marker.
  oneof response {
    // The split to process.
    TensorProto split = 1;

    // If true, there are no splits left to be processed for this stream.
    bool end_of_splits = 2;
  }

  // The local split index within the stream source, starting at zero and
  // incrementing by one for each split assigned to the worker. This local index
  // is used by the worker to keep track of which split it has read up to. If
  // `end_of_splits` is true, this equals the total number of splits.
  int64 local_split_index = 3;
}

// Request message for `DispatcherService.GetSnapshotStreams`.
// Next tag: 2
message GetSnapshotStreamsRequest {
  // The path at which the snapshot is being materialized.
  string path = 1;
}

// Response message for `DispatcherService.GetSnapshotStreams`.
// Next tag: 2
message GetSnapshotStreamsResponse {
  // Information about all streams for the snapshot.
  repeated SnapshotStreamInfo streams = 1;
}

// Request message for `DispatcherService.DisableCompressionAtRuntime`, which
// returns information about the decision to disable compression at runtime for
// the given dataset.
// Next tag: 4
message DisableCompressionAtRuntimeRequest {
  // ID of the dataset the decision is about.
  string dataset_id = 1;

  // Information about the calling trainer. If the dispatcher has already
  // decided whether or not to disable compression at runtime for `dataset_id`,
  // this information is ignored. If the dispatcher has not yet decided, this
  // information is used to decide.
  oneof trainer_info {
    // If set, the calling trainer is eligible for runtime compression
    // disabling, and its host has properties described by this field, which
    // will be used by the dispatcher to decide whether or not to disable
    // compression.
    string trainer_compression_info = 2;

    // If `true`, the calling trainer is ineligible for runtime compression
    // disabling, and the dispatcher will decide to not disable compression.
    bool trainer_is_ineligible = 3;
  }
}

// Response message for `DispatcherService.DisableCompressionAtRuntime`.
// Next tag: 4
message DisableCompressionAtRuntimeResponse {
  // The state of the dispatcher's decision. At most one member is set; callers
  // should check which member is set before reading its value.
  oneof decision {
    // If `true`, the dispatcher has decided to disable compression, and workers
    // and trainers will ignore the compression specified at registration time.
    // If `false`, the dispatcher has decided to not disable compression, and
    // workers and trainers will adhere to the compression specified at
    // registration time.
    bool compression_disabled_at_runtime = 1;

    // If `true`, the dispatcher has not yet made a decision as to whether or
    // not to disable compression at runtime.
    bool not_enough_information = 2;

    // If `true`, compression was not specified at registration time, so there
    // is no runtime compression disabling decision to make.
    bool no_compression_to_disable = 3;
  }
}

// RPC interface of the tf.data service dispatcher. It is called by workers
// (e.g. `WorkerHeartbeat`, `WorkerUpdate`, `GetSnapshotSplit`) and by clients
// (e.g. `ClientHeartbeat`). Every RPC is unary and has its own dedicated
// request and response message.
service DispatcherService {
  // Performs a periodic worker heartbeat.
  rpc WorkerHeartbeat(WorkerHeartbeatRequest) returns (WorkerHeartbeatResponse);

  // Updates the dispatcher with information about the worker's state.
  rpc WorkerUpdate(WorkerUpdateRequest) returns (WorkerUpdateResponse);

  // Gets a dataset definition.
  rpc GetDatasetDef(GetDatasetDefRequest) returns (GetDatasetDefResponse);

  // Gets the next split for a given iteration.
  rpc GetSplit(GetSplitRequest) returns (GetSplitResponse);

  // Returns the API version of the server.
  rpc GetVersion(GetVersionRequest) returns (GetVersionResponse);

  // Registers a dataset with the server, or returns its id if it is already
  // registered.
  //
  // The dataset is constructed in a new graph, so it must not refer to
  // external resources or variables.
  rpc GetOrRegisterDataset(GetOrRegisterDatasetRequest)
      returns (GetOrRegisterDatasetResponse);

  // Gets a job if it already exists, otherwise creates it.
  rpc GetOrCreateJob(GetOrCreateJobRequest) returns (GetOrCreateJobResponse);

  // Gets an iteration if it already exists, otherwise creates it.
  rpc GetOrCreateIteration(GetOrCreateIterationRequest)
      returns (GetOrCreateIterationResponse);

  // Attempts to remove a task from a round-robin read iteration.
  rpc MaybeRemoveTask(MaybeRemoveTaskRequest) returns (MaybeRemoveTaskResponse);

  // Releases an iteration client so that an iteration may eventually be cleaned
  // up.
  rpc ReleaseIterationClient(ReleaseIterationClientRequest)
      returns (ReleaseIterationClientResponse);

  // Heartbeats from the client. This lets the dispatcher know that the client
  // is still active, and gives the dispatcher a chance to notify the client
  // of new tasks.
  rpc ClientHeartbeat(ClientHeartbeatRequest) returns (ClientHeartbeatResponse);

  // Reports a list of all workers registered with the dispatcher.
  rpc GetWorkers(GetWorkersRequest) returns (GetWorkersResponse);

  // Returns the data service metadata for the registered dataset.
  rpc GetDataServiceMetadata(GetDataServiceMetadataRequest)
      returns (GetDataServiceMetadataResponse);

  // Returns the config of a data service cluster.
  rpc GetDataServiceConfig(GetDataServiceConfigRequest)
      returns (GetDataServiceConfigResponse);

  // Initiates the process of materializing a dataset's output to disk.
  rpc Snapshot(SnapshotRequest) returns (SnapshotResponse);

  // Gets the next split for the given stream of the given snapshot. Returns an
  // error if there has been some miscommunication between the worker and
  // dispatcher regarding stream assignment and the worker should stop (though
  // due to stream leases this case should never happen).
  rpc GetSnapshotSplit(GetSnapshotSplitRequest)
      returns (GetSnapshotSplitResponse);

  // Returns information about all streams for the given snapshot.
  rpc GetSnapshotStreams(GetSnapshotStreamsRequest)
      returns (GetSnapshotStreamsResponse);

  // Returns information about the decision to disable compression at runtime
  // for the given dataset.
  rpc DisableCompressionAtRuntime(DisableCompressionAtRuntimeRequest)
      returns (DisableCompressionAtRuntimeResponse);
}
